import os
import cv2
import numpy as np
import pandas as pd

# Absolute locations of the image directory and the three hash-table CSV
# files, resolved against the current working directory at import time.
imagePath, avgHashTable, differenceHashTable, dctHashTable = (
	os.path.abspath(relPath)
	for relPath in ("images", "avgHashTable.csv", "differenceHashTable.csv", "dctHashTable.csv")
)
# Target size handed to cv2.resize, which reads it as (width, height).
size = (32, 32)

def readImagePathList(imageDir):
	"""Return the full paths of the ``.png`` files directly inside *imageDir*.

	Only regular files whose name ends with ".png" (case-sensitive) are
	returned; names that merely contain ".png" (e.g. "a.png.bak") and
	directories are skipped. The list is sorted by file name because
	os.listdir returns entries in arbitrary order, and callers build tables
	from this list.
	"""
	return [
		os.path.join(imageDir, filename)
		for filename in sorted(os.listdir(imageDir))
		if filename.endswith(".png") and os.path.isfile(os.path.join(imageDir, filename))
	]

def readImage(imgPath):
	"""Decode the image file at *imgPath* with OpenCV and return it as an array.

	The file is first loaded as raw bytes with NumPy and then decoded in
	memory by cv2.imdecode instead of cv2.imread -- presumably so that paths
	with non-ASCII characters (such as the Chinese file names used in main)
	work on Windows; TODO confirm.

	The image is decoded unchanged (flag -1 / IMREAD_UNCHANGED), so the
	channel count and bit depth follow the file. cv2.imdecode returns None
	when the bytes cannot be decoded.
	"""
	rawBytes = np.fromfile(imgPath, dtype=np.uint8)
	decoded = cv2.imdecode(rawBytes, cv2.IMREAD_UNCHANGED)
	return decoded

def avgHash(imgPath):
	"""Compute the average (mean) hash of the image at *imgPath*.

	The image is resized to ``size`` (width, height) and converted to
	grayscale. Each pixel then becomes one character: "1" if it is brighter
	than the mean intensity of the resized image, otherwise "0". Pixels are
	emitted row by row, so the result is a string of size[0] * size[1]
	characters.

	Raises:
		ValueError: if the file cannot be decoded as an image.
	"""
	img = readImage(imgPath)
	if img is None:
		raise ValueError("cannot decode image: {0}".format(imgPath))
	img = cv2.resize(img, size)
	# cv2 decodes color images in BGR(A) channel order, so BGR2GRAY (which
	# accepts 3 or 4 channels) applies the correct luminance weights.
	# Single-channel images are already 2-D and need no conversion.
	if img.ndim == 3:
		img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
	# mean() accumulates in float64. A scalar `avg += img[i][j]` loop would
	# stay uint8 under NumPy 2 promotion rules and silently wrap around.
	avg = img.mean()
	return "".join("1" if bit else "0" for bit in (img > avg).flat)

def differenceHash(imgPath):
	"""Compute the difference hash (dHash) of the image at *imgPath*.

	The image is resized to (size[0] + 1) x size[1] (width x height) and
	converted to grayscale. Each pixel is then compared with its right-hand
	neighbour: "1" if the left pixel is brighter, otherwise "0". Rows are
	emitted top to bottom, giving a string of size[0] * size[1] characters.

	Raises:
		ValueError: if the file cannot be decoded as an image.
	"""
	img = readImage(imgPath)
	if img is None:
		raise ValueError("cannot decode image: {0}".format(imgPath))
	# cv2.resize takes dsize as (width, height). The extra pixel must go on
	# the width so that each of the size[1] rows yields size[0] horizontal
	# comparisons.
	img = cv2.resize(img, (size[0] + 1, size[1]))
	# cv2 decodes color images in BGR(A) order; 2-D images are already gray.
	if img.ndim == 3:
		img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
	brighterThanRight = img[:, :-1] > img[:, 1:]
	return "".join("1" if bit else "0" for bit in brighterThanRight.flat)

def dctHash(imgPath):
	"""Compute a DCT-based perceptual hash of the image at *imgPath*.

	The image is resized to ``size``, converted to grayscale and transformed
	with a 2-D DCT. Only the top-left (height // 4) x (width // 4) block of
	coefficients (8 x 8 for the default 32 x 32 size) is kept. Each
	coefficient becomes "1" if it is above the block's mean, otherwise "0",
	emitted row by row.

	Raises:
		ValueError: if the file cannot be decoded as an image.
	"""
	img = readImage(imgPath)
	if img is None:
		raise ValueError("cannot decode image: {0}".format(imgPath))
	img = cv2.resize(img, size)
	# cv2 decodes color images in BGR(A) order; 2-D images are already gray.
	if img.ndim == 3:
		img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
	coeffs = cv2.dct(np.float32(img))
	height, width = coeffs.shape
	lowFreq = coeffs[:height // 4, :width // 4]
	# NOTE(review): the mean includes the DC coefficient lowFreq[0, 0], which
	# is usually much larger than the rest and raises the threshold. Many
	# pHash variants exclude it. It is kept here so hash values stay
	# comparable with the existing behaviour -- confirm intent.
	avg = lowFreq.mean(dtype=np.float64)
	return "".join("1" if bit else "0" for bit in (lowFreq > avg).flat)

def hamming(s1, s2):
	"""Return the similarity of two equal-length hash strings, in [0.0, 1.0].

	Despite the name, this is a similarity rather than a distance. It equals
	1 - (Hamming distance / length), so 1.0 means identical strings and 0.0
	means every position differs. Two empty strings are considered identical.

	Raises:
		ValueError: if the strings have different lengths. ValueError
			subclasses Exception, so existing ``except Exception`` handlers
			still catch it.
	"""
	if len(s1) != len(s2):
		raise ValueError("{0} len({1}) != {2} len({3})".format(s1, len(s1), s2, len(s2)))
	if not s1:
		# Avoid ZeroDivisionError: nothing to compare means nothing differs.
		return 1.0
	mismatches = sum(ch1 != ch2 for ch1, ch2 in zip(s1, s2))
	return 1.0 - mismatches / len(s1)

def cosin(v1, v2):
	"""Map the cosine of the angle between *v1* and *v2* onto [0, 1].

	Returns 0.5 + 0.5 * cos(theta). Up to floating-point rounding, parallel
	vectors give 1.0, orthogonal vectors 0.5 and opposite vectors 0.0. If
	either vector has zero norm the cosine is undefined, and 0 is returned.
	"""
	dotProduct = float(np.dot(v1, v2))
	normProduct = np.linalg.norm(v1) * np.linalg.norm(v2)
	if normProduct == 0:
		return 0
	cosTheta = dotProduct / normProduct
	return 0.5 + 0.5 * cosTheta

def create():
	"""Return an empty DataFrame with the Name / Path / HashCode columns."""
	columnNames = ["Name", "Path", "HashCode"]
	emptyTable = pd.DataFrame(columns=columnNames)
	return emptyTable

def open(path):
	"""Load a hash-table CSV from *path* into a DataFrame.

	NOTE: this function shadows the builtin ``open`` everywhere in this
	module. The name is kept for backward compatibility.

	Name, Path and HashCode are read as strings when present. Hash codes are
	strings of "0"/"1" digits, and letting pandas infer a numeric dtype would
	drop leading zeros and overflow for long hashes.
	"""
	return pd.read_csv(path, dtype={"Name": str, "Path": str, "HashCode": str})

def save(df, path, index=False):
	"""Write *df* to *path* as a UTF-8 encoded CSV file.

	By default the DataFrame index is not written. Writing the positional
	index made ``pd.read_csv`` return an extra "Unnamed: 0" column when the
	table was loaded back. Pass ``index=True`` to keep the index.
	"""
	df.to_csv(path, encoding="utf-8", index=index)

def BuildingTable(tablePath, callback):
	"""Hash every ``.png`` image in ``imagePath`` and save the table to *tablePath*.

	The table has one row per image with these columns:
		Name: the file name without its extension,
		Path: the full image path,
		HashCode: ``callback(path)``.

	Args:
		tablePath: destination CSV file, written via ``save``.
		callback: a function mapping an image path to its hash, e.g. avgHash.
	"""
	rows = []
	for imageName in readImagePathList(imagePath):
		# basename/splitext use the platform's path separator(s). Searching
		# for a hard-coded "/" fails on Windows paths.
		name = os.path.splitext(os.path.basename(imageName))[0]
		rows.append({"Name": name, "Path": imageName, "HashCode": callback(imageName)})
	# DataFrame.append was removed in pandas 2.0. Building the frame once is
	# also linear instead of copying the table on every row.
	df = pd.DataFrame(rows, columns=["Name", "Path", "HashCode"])
	save(df, tablePath)

def feature(imgPath):
	"""Return a per-row brightness profile of the image at *imgPath*.

	The image is resized to ``size`` and converted to grayscale. Each of the
	resulting rows is then reduced to the mean of its pixels that are not pure
	white (value 255). Rows with no such pixel give 0. The vector is divided
	by 255, so for 8-bit images every entry lies in [0, 1].

	Returns:
		A 1-D float64 NumPy array with one entry per row (size[1] entries).

	Raises:
		ValueError: if the file cannot be decoded as an image.
	"""
	img = readImage(imgPath)
	if img is None:
		raise ValueError("cannot decode image: {0}".format(imgPath))
	img = cv2.resize(img, size)
	# cv2 decodes color images in BGR(A) order; 2-D images are already gray.
	if img.ndim == 3:
		img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
	# Pure-white pixels are skipped, presumably as background -- TODO confirm.
	keep = img != 255
	counts = keep.sum(axis=1)
	# Sum in int64: a scalar `avg += img[i][j]` loop stays uint8 under NumPy 2
	# promotion rules and silently wraps around.
	sums = np.where(keep, img, 0).sum(axis=1, dtype=np.int64)
	rowMeans = sums / np.maximum(counts, 1)
	return rowMeans / 255

def BuildingTables():
	"""Build and save the average-, difference- and DCT-hash tables, in that order."""
	tableSpecs = (
		(avgHashTable, avgHash),
		(differenceHashTable, differenceHash),
		(dctHashTable, dctHash),
	)
	for tableFile, hashFunc in tableSpecs:
		BuildingTable(tableFile, hashFunc)

def main():
	"""Demo: find images in ``imagePath`` that look like a reference image.

	For every image, the score is printed and the image is shown in an
	OpenCV window when the score is at least 0.98. Each window blocks on
	cv2.waitKey() until a key is pressed. This runs twice:
	  1. cosine similarity (cosin) of the per-row features (feature);
	  2. Hamming similarity (hamming) of the average hashes (avgHash).
	"""
	threshold = 0.98
	# os.path.join builds the path portably; a literal "images\\..." string
	# only works on Windows.
	sourcePath = os.path.join(imagePath, "炎兔儿.png")
	imagePaths = readImagePathList(imagePath)

	# Cosine similarity
	sourceVec = feature(sourcePath)
	for imageName in imagePaths:
		score = cosin(sourceVec, feature(imageName))
		print(score)
		if score >= threshold:
			cv2.imshow("", readImage(imageName))
			cv2.waitKey()

	# Hamming similarity
	callback = avgHash
	sourceHash = callback(sourcePath)
	for imageName in imagePaths:
		score = hamming(sourceHash, callback(imageName))
		print(score)
		if score >= threshold:
			cv2.imshow("", readImage(imageName))
			cv2.waitKey()

if __name__ == "__main__":
	# Falling off the end already exits with status 0. The site-provided
	# exit() helper is meant for the interactive interpreter, is absent under
	# `python -S`, and is not needed here.
	main()